{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "Multitrack MusicVAE.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QI5g-x4foZls",
        "colab_type": "text"
      },
      "source": [
        "# Multitrack MusicVAE: Learning a Latent Space of Multitrack Measures\n",
        "### ___Ian Simon, Adam Roberts, Colin Raffel, Jesse Engel, Curtis Hawthorne, Douglas Eck___\n",
        "\n",
        "[MusicVAE](https://g.co/magenta/music-vae) learns a latent space of musical sequences.  Here we apply the MusicVAE framework to single measures of multi-instrument General MIDI, a symbolic music representation that uses a standard set of 128 instrument sounds.\n",
        "\n",
        "The models in this notebook are capable of encoding and decoding single measures of up to 8 tracks, optionally conditioned on an underlying chord.  Encoding transforms a single measure into a vector in a latent space, and decoding transforms a latent vector back into a measure.  Both encoding and decoding are performed hierarchically, with one level operating on tracks and another operating on the notes (and choice of instrument) in each track.\n",
        "\n",
        "See our [arXiv paper](https://arxiv.org/abs/1806.00195) for more details, along with our [blog post](http://g.co/magenta/multitrack) with links to JavaScript CodePens."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rDMQbHPYVKmV",
        "colab_type": "text"
      },
      "source": [
        "# Environment Setup"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tciXVi5eWG_1",
        "colab_type": "code",
        "cellView": "form",
        "colab": {}
      },
      "source": [
        "#@title Setup Environment\n",
        "\n",
        "print('Copying checkpoints and modified SGM SoundFont (https://sites.google.com/site/soundfonts4u) from GCS.')\n",
        "print('This will take a few minutes...')\n",
        "!gsutil -q -m cp gs://download.magenta.tensorflow.org/models/music_vae/multitrack/* /content/\n",
        "!gsutil -q -m cp gs://download.magenta.tensorflow.org/soundfonts/SGM-v2.01-Sal-Guit-Bass-V1.3.sf2 /content/\n",
        "\n",
        "print('Installing dependencies...')\n",
        "!apt-get update -qq && apt-get install -qq libfluidsynth2 build-essential libasound2-dev libjack-dev\n",
        "!pip install -qU magenta pyfluidsynth pretty_midi\n",
        "\n",
        "print('Importing libraries...')\n",
        "\n",
        "import numpy as np\n",
        "import os\n",
        "import tensorflow.compat.v1 as tf\n",
        "\n",
        "from google.colab import files\n",
        "\n",
        "import magenta.music as mm\n",
        "from magenta.music.sequences_lib import concatenate_sequences\n",
        "from magenta.models.music_vae import configs\n",
        "from magenta.models.music_vae.trained_model import TrainedModel\n",
        "\n",
        "tf.disable_v2_behavior()\n",
        "print('Done!')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3URxzTQyXfdO",
        "colab_type": "code",
        "cellView": "form",
        "colab": {}
      },
      "source": [
        "#@title Definitions\n",
        "\n",
        "BATCH_SIZE = 4\n",
        "Z_SIZE = 512\n",
        "TOTAL_STEPS = 512\n",
        "BAR_SECONDS = 2.0\n",
        "CHORD_DEPTH = 49\n",
        "\n",
        "SAMPLE_RATE = 44100\n",
        "SF2_PATH = '/content/SGM-v2.01-Sal-Guit-Bass-V1.3.sf2'\n",
        "\n",
        "# Play sequence using SoundFont.\n",
        "def play(note_sequences):\n",
        "  if not isinstance(note_sequences, list):\n",
        "    note_sequences = [note_sequences]\n",
        "  for ns in note_sequences:\n",
        "    mm.play_sequence(ns, synth=mm.fluidsynth, sf2_path=SF2_PATH)\n",
        "  \n",
        "# Spherical linear interpolation.\n",
        "def slerp(p0, p1, t):\n",
        "  \"\"\"Spherical linear interpolation.\"\"\"\n",
        "  omega = np.arccos(np.dot(np.squeeze(p0/np.linalg.norm(p0)), np.squeeze(p1/np.linalg.norm(p1))))\n",
        "  so = np.sin(omega)\n",
        "  return np.sin((1.0-t)*omega) / so * p0 + np.sin(t*omega)/so * p1\n",
        "\n",
        "# Download sequence.\n",
        "def download(note_sequence, filename):\n",
        "  mm.sequence_proto_to_midi_file(note_sequence, filename)\n",
        "  files.download(filename)\n",
        "\n",
        "# Chord encoding tensor.\n",
        "def chord_encoding(chord):\n",
        "  index = mm.TriadChordOneHotEncoding().encode_event(chord)\n",
        "  c = np.zeros([TOTAL_STEPS, CHORD_DEPTH])\n",
        "  c[0,0] = 1.0\n",
        "  c[1:,index] = 1.0\n",
        "  return c\n",
        "\n",
        "# Trim sequences to exactly one bar.\n",
        "def trim_sequences(seqs, num_seconds=BAR_SECONDS):\n",
        "  for i in range(len(seqs)):\n",
        "    seqs[i] = mm.extract_subsequence(seqs[i], 0.0, num_seconds)\n",
        "    seqs[i].total_time = num_seconds\n",
        "\n",
        "# Consolidate instrument numbers by MIDI program.\n",
        "def fix_instruments_for_concatenation(note_sequences):\n",
        "  instruments = {}\n",
        "  for i in range(len(note_sequences)):\n",
        "    for note in note_sequences[i].notes:\n",
        "      if not note.is_drum:\n",
        "        if note.program not in instruments:\n",
        "          if len(instruments) >= 8:\n",
        "            instruments[note.program] = len(instruments) + 2\n",
        "          else:\n",
        "            instruments[note.program] = len(instruments) + 1\n",
        "        note.instrument = instruments[note.program]\n",
        "      else:\n",
        "        note.instrument = 9\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
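    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The `slerp` helper above interpolates along the arc between two latent vectors rather than along the straight line between them. The standalone sketch below uses only NumPy; `slerp_sketch` and the 3-D test vectors are illustrative names and values, not part of the notebook or of Magenta. It suggests why this matters: for two unit vectors the spherical midpoint keeps unit norm, while the linear midpoint shrinks toward the origin.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def slerp_sketch(p0, p1, t):\n",
        "  # Angle between the two vectors, with the cosine clipped for safety.\n",
        "  cos_omega = np.dot(p0 / np.linalg.norm(p0), p1 / np.linalg.norm(p1))\n",
        "  omega = np.arccos(np.clip(cos_omega, -1.0, 1.0))\n",
        "  so = np.sin(omega)\n",
        "  return np.sin((1.0 - t) * omega) / so * p0 + np.sin(t * omega) / so * p1\n",
        "\n",
        "p0 = np.array([1.0, 0.0, 0.0])\n",
        "p1 = np.array([0.0, 1.0, 0.0])\n",
        "mid_slerp = slerp_sketch(p0, p1, 0.5)\n",
        "mid_lerp = 0.5 * p0 + 0.5 * p1\n",
        "\n",
        "print(np.linalg.norm(mid_slerp))  # close to 1.0\n",
        "print(np.linalg.norm(mid_lerp))   # about 0.707\n",
        "```\n",
        "\n",
        "The sketch assumes the two vectors are not parallel; otherwise `so` is zero and the weights are undefined."
      ]
    },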
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pl3oY0w8gBJh",
        "colab_type": "text"
      },
      "source": [
        "# Chord-Conditioned Model"
      ]
    },
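    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The cells in this section pass the output of `chord_encoding` to the model as `c_input`: one row per decoding step and one column per chord class. As a rough sketch of that layout (NumPy only; `chord_index = 5` is an arbitrary stand-in for whatever `TriadChordOneHotEncoding().encode_event` returns, and the sizes mirror `TOTAL_STEPS` and `CHORD_DEPTH` from the definitions above):\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "total_steps, chord_depth = 512, 49\n",
        "chord_index = 5  # hypothetical class index, for illustration only\n",
        "\n",
        "c = np.zeros([total_steps, chord_depth])\n",
        "c[0, 0] = 1.0             # first step uses class 0\n",
        "c[1:, chord_index] = 1.0  # remaining steps use the chord's class\n",
        "\n",
        "# Every row is one-hot: exactly one class is active per step.\n",
        "assert np.all(c.sum(axis=1) == 1.0)\n",
        "print(c.shape)\n",
        "```\n",
        "\n",
        "Because the chord enters only through this tensor, the same latent vector can be decoded under different chords, which is what the chord-progression cells below rely on."
      ]
    },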
    {
      "cell_type": "code",
      "metadata": {
        "id": "n8z2zwC9gF2i",
        "colab_type": "code",
        "cellView": "both",
        "colab": {}
      },
      "source": [
        "#@title Load Checkpoint\n",
        "\n",
        "config = configs.CONFIG_MAP['hier-multiperf_vel_1bar_med_chords']\n",
        "model = TrainedModel(\n",
        "    config, batch_size=BATCH_SIZE,\n",
        "    checkpoint_dir_or_path='/content/model_chords_fb64.ckpt')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_1ybYgKSgIt-",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#@title Same Chord, Random Styles\n",
        "\n",
        "chord = 'C' #@param {type:\"string\"}\n",
        "temperature = 0.2 #@param {type:\"slider\", min:0.01, max:1.5, step:0.01}\n",
        "seqs = model.sample(n=BATCH_SIZE, length=TOTAL_STEPS, temperature=temperature,\n",
        "                    c_input=chord_encoding(chord))\n",
        "\n",
        "trim_sequences(seqs)\n",
        "play(seqs)\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qSmhyk6PgQbI",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#@title Same Style, Chord Progression\n",
        "\n",
        "chord_1 = 'C' #@param {type:\"string\"}\n",
        "chord_2 = 'Caug' #@param {type:\"string\"}\n",
        "chord_3 = 'Am' #@param {type:\"string\"}\n",
        "chord_4 = 'E' #@param {type:\"string\"}\n",
        "chords = [chord_1, chord_2, chord_3, chord_4]\n",
        "\n",
        "temperature = 0.2 #@param {type:\"slider\", min:0.01, max:1.5, step:0.01}\n",
        "z = np.random.normal(size=[1, Z_SIZE])\n",
        "seqs = [\n",
        "    model.decode(length=TOTAL_STEPS, z=z, temperature=temperature,\n",
        "                 c_input=chord_encoding(c))[0]\n",
        "    for c in chords\n",
        "]\n",
        "\n",
        "trim_sequences(seqs)\n",
        "fix_instruments_for_concatenation(seqs)\n",
        "prog_ns = concatenate_sequences(seqs)\n",
        "\n",
        "play(prog_ns)\n",
        "mm.plot_sequence(prog_ns)\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xwQyOx8E4NLK",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#@title (Optional) Save Arrangement to MIDI\n",
        "download(prog_ns, '_'.join(chords) + '.mid')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "2bXqnFz9gAya",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#@title Style Interpolation, Repeating Chord Progression\n",
        "\n",
        "chord_1 = 'Dm' #@param {type:\"string\"}\n",
        "chord_2 = 'F' #@param {type:\"string\"}\n",
        "chord_3 = 'Am' #@param {type:\"string\"}\n",
        "chord_4 = 'G' #@param {type:\"string\"}\n",
        "chords = [chord_1, chord_2, chord_3, chord_4]\n",
        "\n",
        "num_bars = 32 #@param {type:\"slider\", min:4, max:64, step:4}\n",
        "temperature = 0.2 #@param {type:\"slider\", min:0.01, max:1.5, step:0.01}\n",
        "\n",
        "z1 = np.random.normal(size=[Z_SIZE])\n",
        "z2 = np.random.normal(size=[Z_SIZE])\n",
        "z = np.array([slerp(z1, z2, t)\n",
        "              for t in np.linspace(0, 1, num_bars)])\n",
        "\n",
        "seqs = [\n",
        "    model.decode(length=TOTAL_STEPS, z=z[i:i+1, :], temperature=temperature,\n",
        "                 c_input=chord_encoding(chords[i % 4]))[0]\n",
        "    for i in range(num_bars)\n",
        "]\n",
        "\n",
        "trim_sequences(seqs)\n",
        "fix_instruments_for_concatenation(seqs)\n",
        "prog_interp_ns = concatenate_sequences(seqs)\n",
        "\n",
        "play(prog_interp_ns)\n",
        "mm.plot_sequence(prog_interp_ns)\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "AsPJX7Lzfek5",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#@title (Optional) Save to MIDI\n",
        "download(prog_interp_ns, 'interp_' + '_'.join(chords) + '.mid')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "loV8bwJ8fOR_",
        "colab_type": "text"
      },
      "source": [
        "# Unconditioned Model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Cao1ezEDfRQG",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#@title Load Checkpoint\n",
        "\n",
        "config = configs.CONFIG_MAP['hier-multiperf_vel_1bar_med']\n",
        "model = TrainedModel(\n",
        "    config, batch_size=BATCH_SIZE,\n",
        "    checkpoint_dir_or_path='/content/model_fb256.ckpt')\n",
        "model._config.data_converter._max_tensors_per_input = None"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WnAZsIYsfXWV",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#@title Random Samples\n",
        "\n",
        "temperature = 0.2 #@param {type:\"slider\", min:0.01, max:1.5, step:0.01}\n",
        "seqs = model.sample(n=BATCH_SIZE, length=TOTAL_STEPS, temperature=temperature)\n",
        "\n",
        "trim_sequences(seqs)\n",
        "play(seqs)\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "glPwRID0fsC-",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#@title Interpolation Between Random Samples\n",
        "\n",
        "num_bars = 32 #@param {type:\"slider\", min:4, max:64, step:1}\n",
        "temperature = 0.2 #@param {type:\"slider\", min:0.01, max:1.5, step:0.01}\n",
        "\n",
        "z1 = np.random.normal(size=[Z_SIZE])\n",
        "z2 = np.random.normal(size=[Z_SIZE])\n",
        "z = np.array([slerp(z1, z2, t)\n",
        "              for t in np.linspace(0, 1, num_bars)])\n",
        "\n",
        "seqs = model.decode(length=TOTAL_STEPS, z=z, temperature=temperature)\n",
        "\n",
        "trim_sequences(seqs)\n",
        "fix_instruments_for_concatenation(seqs)\n",
        "interp_ns = concatenate_sequences(seqs)\n",
        "\n",
        "play(interp_ns)\n",
        "mm.plot_sequence(interp_ns)\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "yyPUVh2cyZF2",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#@title (Optional) Save to MIDI\n",
        "download(interp_ns, 'interp.mid')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TAs8Stysy640",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#@title Upload MIDI Files to Reconstruct\n",
        "midi_files = files.upload().values()\n",
        "seqs = [mm.midi_to_sequence_proto(midi) for midi in midi_files]\n",
        "\n",
        "uploaded_seqs = []\n",
        "for seq in seqs:\n",
        "  _, tensors, _, _ = model._config.data_converter.to_tensors(seq)\n",
        "  uploaded_seqs.extend(model._config.data_converter.from_tensors(tensors))\n",
        "  \n",
        "trim_sequences(uploaded_seqs)\n",
        "\n",
        "print('Parsed %d measures' % len(uploaded_seqs))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ka9VaIh31rsE",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#@title Encode and Decode\n",
        "\n",
        "index = 0 #@param {type:\"integer\"}\n",
        "temperature = 0.2 #@param {type:\"slider\", min:0.01, max:1.5, step:0.01}\n",
        "\n",
        "z, _, _ = model.encode([uploaded_seqs[index]])\n",
        "reconstructed_seq = model.decode(z, length=TOTAL_STEPS,\n",
        "                                 temperature=temperature)[0]\n",
        "\n",
        "trim_sequences([reconstructed_seq])\n",
        "\n",
        "print('Original')\n",
        "play(uploaded_seqs[index])\n",
        "mm.plot_sequence(uploaded_seqs[index])\n",
        "\n",
        "print('Reconstructed')\n",
        "play(reconstructed_seq)\n",
        "mm.plot_sequence(reconstructed_seq)\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "PvDLoDwPgeNV",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#@title Interpolation Between Encodings\n",
        "\n",
        "index_1 = 0 #@param {type:\"integer\"}\n",
        "index_2 = 1 #@param {type:\"integer\"}\n",
        "\n",
        "num_bars = 32 #@param {type:\"slider\", min:4, max:64, step:4}\n",
        "temperature = 0.2 #@param {type:\"slider\", min:0.01, max:1.5, step:0.01}\n",
        "\n",
        "z1, _, _ = model.encode([uploaded_seqs[index_1]])\n",
        "z2, _, _ = model.encode([uploaded_seqs[index_2]])\n",
        "z = np.array([slerp(np.squeeze(z1), np.squeeze(z2), t)\n",
        "              for t in np.linspace(0, 1, num_bars)])\n",
        "\n",
        "seqs = model.decode(length=TOTAL_STEPS, z=z, temperature=temperature)\n",
        "\n",
        "trim_sequences(seqs)\n",
        "fix_instruments_for_concatenation(seqs)\n",
        "recon_interp_ns = concatenate_sequences(seqs)\n",
        "\n",
        "play(recon_interp_ns)\n",
        "mm.plot_sequence(recon_interp_ns)\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "G7eLLp3q5WCB",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "#@title (Optional) Save to MIDI\n",
        "download(recon_interp_ns, 'recon_interp.mid')"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}
